open Ast
open Lexing
open Parsing
open Utils

(* String-keyed map; used below for the local and global type environments. *)
module NameMap = Map.Make(String)
(* Char-keyed map; not referenced in the visible part of this file. *)
module GenericMap = Map.Make(Char)
(* Set of strings; used below to hold the reserved identifier names. *)
module KeywordsSet = Set.Make(String)

(* Global counter for fresh type variables, stored as a list of upper-case
   letters.  [get_new_type] reads the current value and then advances it. *)
let type_variable = ref [ 'A'; 'A'; 'A' ];;

(* Identifiers that programs may not bind: they are folded into
   [all_keywords], which [annotate_expr] consults for function formals and
   assignment targets.  Presumably the reserved words of the JavaScript
   target plus the source keyword val - TODO confirm. *)
let js_keywords =
  [ "break"; "case"; "class"; "catch"; "const"; "continue";
    "debugger"; "default"; "delete"; "do"; "else";
    "export"; "extends"; "finally"; "for"; "function"; "if";
    "import"; "in"; "instanceof"; "new"; "return"; "super";
    "switch"; "this"; "throw"; "try"; "typeof"; "var"; "void";
    "while"; "with"; "yield"; "val" ]
;;

(* Names that are reserved alongside [js_keywords]; both lists are folded
   into [all_keywords], and this list alone into [spy_keywords].
   Presumably the names of the built-in library functions - TODO confirm
   against Library. *)
let spy_lib =
  [ "print"; "print_num"; "print_str"; "print_bool";
    "print_list"; "print_map"; "to_str"; "hd"; "empty";
    "tl"; "contains"; "del"; "keys"; "len"; "filter"; "map";
    "foldl"; "count" ]
;;

(* Every reserved identifier: the entries of [spy_lib] followed by those of
   [js_keywords], collected into one set for membership tests. *)
let all_keywords =
  let insert set keyword = KeywordsSet.add keyword set in
  List.fold_left insert KeywordsSet.empty (List.append spy_lib js_keywords)
;;

(* The entries of [spy_lib] alone, as a set. *)
let spy_keywords =
  let insert set keyword = KeywordsSet.add keyword set in
  List.fold_left insert KeywordsSet.empty spy_lib
;;

(* [get_new_type ()] returns a fresh type variable [T name], where [name] is
   the current content of [type_variable] written as a string, and then
   advances [type_variable] to its successor.  The successor is computed
   like an odometer over the letters: the last letter is bumped, and a Z
   wraps around to A while carrying into the letter before it (a carry out
   of the first letter adds one more letter). *)
let get_new_type () =
  (* [bump] works on the letters in reverse order, last letter first. *)
  let rec bump = function
    | [] -> ['A']
    | 'Z' :: rest -> 'A' :: bump rest
    | c :: rest -> Char.chr (Char.code c + 1) :: rest
  in
  let current = !type_variable in
  type_variable := List.rev (bump (List.rev current));
  T (String.concat "" (List.map Char.escaped current))
;;

(* [merge_env (locals, globals)] folds the local bindings into the global
   ones and returns the pair (empty locals, merged globals).  When a name is
   bound in both maps the local binding is the one that is kept. *)
let merge_env env =
  let locals, globals = env in
  let prefer_local _name local_binding global_binding =
    match local_binding with
    | Some _ -> local_binding
    | None -> global_binding
  in
  NameMap.empty, NameMap.merge prefer_local locals globals
;;

(* [annotate_expr e env] turns the untyped expression [e] into its annotated
   counterpart and returns it together with an environment.

   [env] is a pair (locals, globals) of NameMaps from identifier to type.
   Literals receive their concrete type; most other nodes receive a fresh
   type variable from [get_new_type], which [infer] later resolves through
   [get_constraints] and [unify].  Assign is the only case that returns an
   environment different from the one it was given (the new binding is added
   to the locals).

   Raises Failure when an identifier is used before it is defined, when a
   name is rebound in the current scope, when a reserved name (see
   [all_keywords]) is used as a formal or as an assignment target, or when
   the expression has an unexpected shape.

   Note: fresh type variables come from a global counter, so the names that
   end up in the result depend on the order in which the sub-expressions are
   visited here. *)
let rec annotate_expr e env =
  match e with
    VoidLit       -> AVoidLit(TVoid), env
  | NumLit(n)     -> ANumLit(n, TNum), env
  | BoolLit(b)    -> ABoolLit(b, TBool), env
  | StringLit(s)  -> AStringLit(s, TString), env

  (* Both operands are annotated under the incoming env; the environments
     they return are dropped. *)
  | Binop(e1, op, e2) ->
    let ae1 = fst(annotate_expr e1 env)
    and ae2 = fst(annotate_expr e2 env)
    and new_type = get_new_type () in
    ABinop(ae1, op, ae2, new_type), env

  | Unop(op, e) -> 
    let ae = fst(annotate_expr e env) 
    and new_type = get_new_type () in
    AUnop(op, ae, new_type), env

  (* Identifier lookup: locals are searched before globals. *)
  | Val(id) ->
    let locals, globals = env in
    let typ = if NameMap.mem id locals
              then NameMap.find id locals
              else if NameMap.mem id globals
              then NameMap.find id globals
              else raise (Failure("Error: value " ^ id ^ " was used before it was defined")) in
    AVal(id, typ), env

  | FunLit(formals, e) -> begin
      (* Formals may not reuse a reserved name. *)
      List.iter (fun k -> if KeywordsSet.mem k all_keywords
                          then raise (Failure("Error: Cannot redefine keyword " ^ k)) 
                          else ()) formals;
      (* Every formal gets its own fresh type variable. *)
      let annotated_args = List.map (fun formal -> (formal, get_new_type ())) formals in
      (* The body starts in a new scope: current locals are folded into the
         globals and the local map starts out empty. *)
      let locals, globals = merge_env env in
      (* Bind the formals as locals, rejecting a name that appears twice. *)
      let new_locals = List.fold_left (fun map (it, at) -> if NameMap.mem it map
                                                           then raise (Failure("Error: " ^ it ^ " cannot be redefined in the current scope"))
                                                           else NameMap.add it at map) locals annotated_args in
      let new_env = (new_locals, globals) in
      (* The body must be a Block.  Its expressions are annotated in order,
         threading the environment so that each one sees the bindings made
         by the ones before it. *)
      let ae, _ = (match e with
            Block(es) -> begin
                let aes, _ = List.fold_left (fun (aes, env) e -> let ae, env = annotate_expr e env in (ae :: aes, env)) ([], new_env) es
                in ABlock(List.rev aes, get_new_type()), new_env
              end
            | _ -> raise (Failure("Unreachable state: FunLit"))) in 
      let arg_types = List.map snd annotated_args in 
      (* The return type is a fresh variable; the caller env is returned
         unchanged. *)
      AFunLit(formals, ae, TFun(arg_types, get_new_type ())), env
    end

  (* Callee and arguments are all annotated under the incoming env. *)
  | Call(fn, args) -> begin
      let afn, _  = annotate_expr fn env in
      let aargs = List.map (fun arg -> fst (annotate_expr arg env)) args in
      ACall(afn, aargs, get_new_type ()), env
    end

  | Assign(id, e) -> begin
      let locals, globals = env in
      if NameMap.mem id locals
      then raise (Failure("Error: " ^ id ^ " cannot be redefined in the current scope"))
      else if KeywordsSet.mem id all_keywords
      then raise (Failure("Error: Cannot redefine keyword " ^ id)) 
      else let t = get_new_type () in
        let new_locals = NameMap.add id t locals in
        (* A function literal is annotated with the new binding already in
           scope, so its body can refer to the name being defined; any other
           right-hand side only sees the previous environment. *)
        let ae, _ = match e with
            FunLit(_) -> annotate_expr e (new_locals, globals)
          | _ -> annotate_expr e env in 
        (* The assignment node itself is typed TVoid; [t] is the type
           variable recorded for the bound name. *)
        AAssign(id, t, ae, TVoid), (new_locals, globals)
    end

  (* One fresh variable stands for the element type of the list. *)
  | ListLit(es) ->
    let aes = List.map (fun e -> fst (annotate_expr e env)) es in
    AListLit(aes, TList(get_new_type ())), env

  (* Two fresh variables stand for the key type and the value type. *)
  | DictLit(kvpairs) ->
    let apairs = List.map (fun (k, v) -> fst (annotate_expr k env), fst (annotate_expr v env)) kvpairs in
    ADictLit(apairs, TDict(get_new_type (), get_new_type ())), env

  (* [pepairs] holds (predicate, branch) pairs; [e] is the final branch. *)
  | If(pepairs, e) ->
    let apairs = List.map (fun (p, e) -> fst (annotate_expr p env), fst (annotate_expr e env)) pepairs 
    and ae = fst (annotate_expr e env) in 
    AIf(apairs, ae, get_new_type ()), env

  (* A block opens a new scope via [merge_env] and threads the environment
     through its expressions; the original env is returned, so bindings made
     inside the block are not visible after it. *)
  | Block(es) -> begin
      let new_env = merge_env env in
      let aes, new_env = List.fold_left (fun (aes, env) e -> let ae, env = annotate_expr e env in (ae :: aes, env)) ([], new_env) es in
      ABlock(List.rev aes, get_new_type ()), env
    end

  | Element(id, e) -> begin
      let locals, globals = env in
      let typ = if NameMap.mem id locals
                then NameMap.find id locals
                else if NameMap.mem id globals
                then NameMap.find id globals
                else raise (Failure("Error: value " ^ id ^ " was used before it was defined")) in
      (* Result type: the element type of a list, the value type of a dict,
         and a fresh variable when the container type is not known yet. *)
      let t = (match typ with 
          TList(t) -> t
        | TDict(t1, t2) -> t2
        | _ -> get_new_type ()) in
      let ae, _ = annotate_expr e env in
      AElement(id, ae, t), env
    end
  | _ -> raise (Failure("Error: unexpected expression found"))
;;

(* [type_of ae] returns the type annotation stored in the annotated
   expression [ae]; for every constructor this is its last component. *)
let type_of ae =
  match ae with
  | AVoidLit t -> t
  | ANumLit(_, t) | ABoolLit(_, t) | AStringLit(_, t)
  | AListLit(_, t) | ADictLit(_, t) | ABlock(_, t) | AVal(_, t) -> t
  | AUnop(_, _, t) | AIf(_, _, t) | ACall(_, _, t)
  | AFunLit(_, _, t) | AElement(_, _, t) -> t
  | ABinop(_, _, _, t) | AAssign(_, _, _, t) -> t
;;

(* [get_constraints ae] collects the type constraints implied by the
   annotated expression [ae].  The result is a list of (type, type) pairs,
   each stating that the two types must be equal; [infer] hands this list to
   [unify].

   Raises Failure when
   - a Cons or Append operand has a type that cannot be a list,
   - a block is empty, or its last expression is an assignment,
   - a call does not go through a plain identifier, passes the wrong number
     of arguments, or targets a value whose type is not a function,
   - a list, dict or function literal carries an annotation of the wrong
     shape (reported as an unreachable state). *)
let rec get_constraints ae =
  match ae with
    AVoidLit(_) | ANumLit(_) | ABoolLit(_) | AStringLit(_) | AVal(_) -> []

  | ABinop(ae1, op, ae2, t) ->
    let et1 = type_of ae1 and et2 = type_of ae2 in
    (* Constraints contributed by the operator itself. *)
    let opc = match op with
        Add | Sub | Mult | Div | Mod  -> [(et1, TNum); (et2, TNum); (t, TNum)]
      | Caret                         -> [(et1, TString); (et2, TString); (t, TString)]
      | And | Or                      -> [(et1, TBool); (et2, TBool); (t, TBool)]
      | Lt | Lte | Gt | Gte           -> [(et1, et2); (t, TBool)]
      | Eq | Neq                      -> [(et1, et2); (t, TBool)]
      (* Cons: the right operand is a list (or still unknown) whose element
         type is the type of the left operand. *)
      | Cons -> (match et2 with
            TList(x) -> [(et1, x); (t, et2)]
          | T(_)     -> [(et2, TList(et1)); (t, TList(et1))]
          | _ -> raise (Failure("Type error: Lists can only contain one type")))
      (* Append: at least one operand must already be known to be a list;
         the other one and the result get the same type. *)
      | Append -> (match et1, et2 with
            TList(_), TList(_) -> [(et1, et2); (t, et2)]
          | T(_), TList(_)     -> [(et1, et2); (t, et2)]
          | TList(_), T(_)     -> [(et2, et1); (t, et1)]
          | _, _ -> raise (Failure("Type error: Lists can only contain one type"))) in
    (get_constraints ae1) @ (get_constraints ae2) @ opc

  | AUnop(op, ae, t) ->
    let et = type_of ae in
    let opc = (match op with
          Not -> [(et, TBool); (t, TBool)]
        | Neg -> [(et, TNum); (t, TNum)]) in
    (get_constraints ae) @ opc

  (* Every predicate is a TBool; every branch, and the whole expression, has
     the type of the final branch [ae]. *)
  | AIf(apairs, ae, t) ->
    let plist = List.map fst apairs in
    let elist = List.map snd apairs in
    let p_conts = List.map (fun p -> (type_of p, TBool)) plist in
    let e_conts = List.map (fun e -> (type_of e, type_of ae)) elist in
    (t, type_of ae) :: (List.flatten (List.map get_constraints plist)) @ (List.flatten (List.map get_constraints elist)) @ (get_constraints ae) @ p_conts @ e_conts

  (* Every element has the element type of the list annotation. *)
  | AListLit(aes, t) ->
    let list_type = match t with
        TList(x) -> x
      | _ -> raise (Failure("Unreachable state: ListLit")) in
    let elem_conts = List.map (fun ae -> (list_type, type_of ae)) aes in
    (List.flatten (List.map get_constraints aes)) @ elem_conts

  (* Every key has the key type and every value the value type of the dict
     annotation. *)
  | ADictLit(kvpairs, t) ->
    let kt, vt = match t with
        TDict(kt, vt) -> kt, vt
      | _ -> raise (Failure("Unreachable state: DictLit")) in
    let klist = List.map fst kvpairs in
    let vlist = List.map snd kvpairs in
    let k_conts = List.map (fun k -> (kt, type_of k)) klist in
    let v_conts = List.map (fun v -> (vt, type_of v)) vlist in
    [(t, TDict(kt, vt))] @ (List.flatten (List.map get_constraints klist)) @ (List.flatten (List.map get_constraints vlist)) @ k_conts @ v_conts

  (* A block has the type of its last expression.  An empty block has no
     last expression, so it is rejected with an explicit message instead of
     letting List.hd fail. *)
  | ABlock(aes, t) ->
    let last_type = (match List.rev aes with
          [] -> raise (Failure("Type error: Block must contain at least one expression"))
        | AAssign(_, _, _, _) :: _ -> raise (Failure("Type error: Assignment expressions cannot be returned"))
        | last :: _ -> type_of last) in
    (t, last_type) :: (List.flatten (List.map get_constraints aes))

  (* The type recorded for the bound name equals the type of the right-hand
     side. *)
  | AAssign(_, t, ae, _) -> (t, type_of ae) :: (get_constraints ae)

  (* The body has the return type of the function annotation. *)
  | AFunLit(_, ae, t) -> (match t with
        TFun(_, ret_type) -> (type_of ae, ret_type) :: (get_constraints ae)
      | _ -> raise (Failure("Unreachable state: FunLit")))

  | ACall(afn, aargs, t) ->
    (* Only calls through a plain identifier are accepted here. *)
    let typ_afn = (match afn with
          AVal(_) -> type_of afn
        | _ -> raise (Failure("Unreachable state: Call"))) in
    let sign_conts = (match typ_afn with
          (* Known signature: check the arity, then tie each argument to its
             formal and the call result to the return type. *)
          TFun(arg_types, ret_type) -> begin
            let l1 = List.length aargs and l2 = List.length arg_types in
            if l1 <> l2
            then raise (Failure("Error: Expected number of argument(s): " ^ (string_of_int l2) ^ ", got " ^ (string_of_int l1) ^ " instead."))
            else let arg_conts = List.map2 (fun ft at -> (type_of at, ft)) arg_types aargs in
            (ret_type, t) :: arg_conts
          end
          (* Callee type still unknown: it must be a function taking these
             argument types and returning the type of the call. *)
        | T(_) -> [(typ_afn, TFun(List.map type_of aargs, t))]
        | _ -> let st = Utils.string_of_type typ_afn in
               let ft = TFun(List.map type_of aargs, t) in
               let sft = Utils.string_of_type ft in
               raise (Failure("Type error: expected value of type " ^ st ^ ", got type " ^ sft))) in
    (get_constraints afn) @ (List.flatten (List.map get_constraints aargs)) @ sign_conts
  (* Element access adds no constraint of its own; only the constraints of
     the index expression are collected. *)
  | AElement(_, ae, _) -> (get_constraints ae)
;;

(* [substitute pt ch typ] returns [typ] with every occurrence of the type
   variable named [ch] replaced by the type [pt].  Base types are returned
   unchanged; function, list and dict types are rewritten component by
   component. *)
let rec substitute pt ch typ =
  let subst = substitute pt ch in
  match typ with
  | T(name) -> if name = ch then pt else typ
  | TList(elem) -> TList(subst elem)
  | TDict(key, value) -> TDict(subst key, subst value)
  | TFun(params, ret) -> TFun(List.map subst params, subst ret)
  | TNum | TBool | TString | TVoid | TAny -> typ
;;

(* [apply subs t] applies the substitution list [subs], a list of
   (variable name, type) pairs, to the type [t].  The pairs are applied
   starting from the last element of [subs] and ending with the first. *)
let apply subs t =
  let step typ (var, replacement) = substitute replacement var typ in
  List.fold_left step t (List.rev subs)
;;

(* [resolve_type ch t] is the occurs check used by [unify_one]: it returns
   false when the type variable named [ch] appears anywhere inside [t]
   (including [t] being that variable itself), and true otherwise. *)
let rec resolve_type ch t =
  match t with
  | T(name) -> name <> ch
  | TList(elem) -> resolve_type ch elem
  | TDict(key, value) -> resolve_type ch key && resolve_type ch value
  | TFun(params, ret) -> List.for_all (resolve_type ch) (ret :: params)
  | _ -> true
;;

(* [unify constraints] solves a list of (type, type) equality constraints
   and returns a substitution list of (variable name, type) pairs suitable
   for [apply].  The tail of the list is solved first; the head pair is then
   rewritten under that solution and solved by [unify_one], whose result is
   placed in front of the solution for the tail.
   Raises Failure when two types cannot be made equal. *)
let rec unify constraints =
  match constraints with
  | [] -> []
  | (lhs, rhs) :: rest ->
    let tail_subs = unify rest in
    let head_subs = unify_one (apply tail_subs lhs) (apply tail_subs rhs) in
    head_subs @ tail_subs

(* [unify_one t1 t2] returns the substitutions that make [t1] and [t2]
   equal, or raises Failure when that is impossible. *)
and unify_one t1 t2 =
  (* Built lazily so the message is only computed when it is raised. *)
  let mismatch () =
    Failure("Type error: expected value of type " ^ (Utils.string_of_type t2) ^ ", got type " ^ (Utils.string_of_type t1)) in
  match t1, t2 with
  | TNum, TNum | TBool, TBool | TString, TString | TVoid, TVoid -> []
  | T(x), z | z, T(x) ->
    (* A variable may be bound to itself, or to any type in which it does
       not occur. *)
    if t1 <> t2 && not (resolve_type x z)
    then raise (mismatch ())
    else [(x, z)]
  | TList(elem1), TList(elem2) -> unify_one elem1 elem2
  | TDict(kt1, vt1), TDict(kt2, vt2) ->
    begin
      (* Only number, boolean, string or still-unknown key types are
         accepted; the check looks at the key type of the first dict. *)
      (match kt1 with
       | TNum | TBool | TString | T(_) -> ()
       | _ -> raise (Failure("Type Error: Cannot have key of type " ^ (Utils.string_of_type kt1) ^ " in a Dict.")));
      unify [(kt1, kt2); (vt1, vt2)]
    end
  | TFun(params1, ret1), TFun(params2, ret2) ->
    let n1 = List.length params1 and n2 = List.length params2 in
    if n1 <> n2
    then raise (Failure("Error: Expected number of argument(s): " ^ (string_of_int n1) ^ ", got " ^ (string_of_int n2) ^ " instead."))
    else unify ((List.combine params1 params2) @ [(ret1, ret2)])
  | _ -> raise (mismatch ())
;;

(* [apply_expr subs ae] rebuilds the annotated expression [ae] with the
   substitution list [subs] applied to every type annotation in it.  The
   last component of an assignment node is always rebuilt as TVoid. *)
let rec apply_expr subs ae =
  let retype = apply subs in
  let recur = apply_expr subs in
  let recur_pair (a, b) = (recur a, recur b) in
  match ae with
  | AVoidLit(t)               -> AVoidLit(retype t)
  | ANumLit(n, t)             -> ANumLit(n, retype t)
  | ABoolLit(b, t)            -> ABoolLit(b, retype t)
  | AStringLit(s, t)          -> AStringLit(s, retype t)
  | AVal(name, t)             -> AVal(name, retype t)
  | AUnop(op, operand, t)     -> AUnop(op, recur operand, retype t)
  | ABinop(lhs, op, rhs, t)   -> ABinop(recur lhs, op, recur rhs, retype t)
  | AAssign(id, t, rhs, _)    -> AAssign(id, retype t, recur rhs, TVoid)
  | AListLit(elems, t)        -> AListLit(List.map recur elems, retype t)
  | ADictLit(kvpairs, t)      -> ADictLit(List.map recur_pair kvpairs, retype t)
  | AIf(branches, final, t)   -> AIf(List.map recur_pair branches, recur final, retype t)
  | ABlock(body, t)           -> ABlock(List.map recur body, retype t)
  | AFunLit(ids, body, t)     -> AFunLit(ids, recur body, retype t)
  | ACall(callee, args, t)    -> ACall(recur callee, List.map recur args, retype t)
  | AElement(id, index, t)    -> AElement(id, recur index, retype t)
;;

(* [infer expr env] runs the whole inference pipeline on one expression:
   annotate it, collect its constraints, solve them, and apply the solution
   to the annotated tree.  Returns the typed expression together with the
   environment produced by the annotation step. *)
let infer expr env =
  let annotated, env_after = annotate_expr expr env in
  let solution = annotated |> get_constraints |> unify in
  (apply_expr solution annotated, env_after)
;;

(* [type_check program] infers types for a list of top-level expressions,
   in order, and returns the list of typed expressions in the same order.

   The environment starts with no locals and with
   [Library.built_in_functions] as globals, and is threaded from one
   expression to the next.  After a top-level assignment, the bound name is
   recorded in the locals with the inferred type of its right-hand side, so
   later expressions see the solved type rather than the placeholder type
   variable introduced during annotation.

   Any Failure raised by [infer] propagates to the caller unchanged. *)
let type_check program =
  let initial_env = (NameMap.empty, Library.built_in_functions) in
  let infer_helper (acc, env) expr =
    let inferred_expr, env = infer expr env in
    let inferred_expr, env = (match inferred_expr with
        AAssign(id, t, ae, _) ->
          let locals, globals = env in
          (* Replace the placeholder binding by the inferred type. *)
          let locals = NameMap.add id (type_of ae) locals in
          AAssign(id, t, ae, TVoid), (locals, globals)
      | _ -> inferred_expr, env) in
    (inferred_expr :: acc, env) in
  (* Results are accumulated in reverse and flipped once at the end. *)
  let inferred_program, _ = List.fold_left infer_helper ([], initial_env) program in
  List.rev inferred_program
;;
